/**
* @file op_model_cache.cpp
*
* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/

#include "op_model_cache.h"

#include "framework/common/util.h"
#include "executor/ge_executor.h"

namespace {
    // File-local counter from which OpModelCache::Add draws each OpModel::opModelId
    // (post-incremented once per Add call). Starts at zero.
    std::atomic<std::uint64_t> atomicId{0U};
} // namespace

namespace acl {
/**
 * Look up a cached op model by its model path.
 *
 * @param modelDef  definition whose modelPath is used as the lookup key
 * @param opModel   [out] receives a copy of the cached entry on a hit
 * @return ACL_SUCCESS on a hit; ACL_ERROR_FAILURE (after logging an inner error) on a miss
 */
aclError OpModelCache::GetOpModel(const OpModelDef &modelDef, OpModel &opModel)
{
    bool found = false;
    {
        // The copy-out happens while the cache lock is held; the lock is released
        // before any logging below.
        const std::lock_guard<std::mutex> guard(mutex_);
        const auto hit = cachedModels_.find(modelDef.modelPath);
        if (hit != cachedModels_.end()) {
            opModel = hit->second;
            found = true;
        }
    }
    if (found) {
        return ACL_SUCCESS;
    }
    ACL_LOG_INNER_ERROR("[Get][OpModel]GetOpModel fail, modelPath = %s", modelDef.modelPath.c_str());
    return ACL_ERROR_FAILURE;
}

/**
 * Insert (or overwrite) the cache entry keyed by modelDef.modelPath.
 *
 * Before storing, opModel is stamped with its cache key (the model path) and a
 * fresh id taken from the file-local atomicId counter. Because opModel is taken by
 * reference, the caller's object also carries these values on return.
 *
 * @param modelDef  definition whose modelPath is used as the cache key
 * @param opModel   [in/out] model to cache; cacheKey and opModelId are written
 * @return ACL_SUCCESS (there is no failure path)
 */
aclError OpModelCache::Add(const OpModelDef &modelDef, OpModel &opModel)
{
    // %p requires a void pointer argument (assuming ACL_LOG_INFO forwards to a
    // printf-style formatter, as its %s/c_str() usage suggests); passing an
    // OpModelDef* unconverted is undefined behavior for that conversion.
    ACL_LOG_INFO("start to execute OpModelCache::Add, modelPath = %s, key = %p", modelDef.modelPath.c_str(),
                 static_cast<const void *>(&modelDef));
    const auto key = modelDef.modelPath;
    opModel.cacheKey = key;
    const std::lock_guard<std::mutex> locker(mutex_);
    opModel.opModelId = atomicId++;
    // NOTE(review): operator[] silently replaces an existing entry with the same
    // path; unlike Delete(), the replaced entry's opModelId is not unloaded here.
    // Confirm callers never re-Add a path that is still cached.
    cachedModels_[key] = opModel;
    return ACL_SUCCESS;
}

/**
 * Remove the cache entry for modelDef.modelPath and unload its single-op resource.
 *
 * The entry is erased from the cache before the unload call; both steps run while
 * the cache lock is held.
 *
 * @param modelDef   definition whose modelPath identifies the entry
 * @param isDynamic  selects ge::GeExecutor::UnloadDynamicSingleOp (true) or
 *                   ge::GeExecutor::UnloadSingleOp (false)
 * @return ACL_SUCCESS when nothing is cached under the path; otherwise the GE
 *         unload status cast to int32_t
 */
aclError OpModelCache::Delete(const OpModelDef &modelDef, const bool isDynamic)
{
    ACL_LOG_INFO("start to execute OpModelCache::Delete, modelPath = %s", modelDef.modelPath.c_str());
    const std::lock_guard<std::mutex> guard(mutex_);
    const auto entry = cachedModels_.find(modelDef.modelPath);
    if (entry == cachedModels_.end()) {
        // Nothing cached under this path: report success without touching GE.
        return ACL_SUCCESS;
    }
    const uint64_t opId = entry->second.opModelId;
    (void)cachedModels_.erase(entry);
    ACL_LOG_INFO("start to unload single op resource %lu", opId);
    if (!isDynamic) {
        return static_cast<int32_t>(ge::GeExecutor::UnloadSingleOp(opId));
    }
    return static_cast<int32_t>(ge::GeExecutor::UnloadDynamicSingleOp(opId));
}

/**
 * Attach (or replace) the runtime-v2 stream executor stored on an existing cache entry.
 *
 * @param modelPath  cache key of the entry to update
 * @param executor   executor to store; moved into the entry on success
 * @return ACL_SUCCESS if the entry exists; ACL_ERROR_FAILURE (after logging an
 *         inner error, with the lock still held) if it does not
 */
aclError OpModelCache::UpdateCachedExecutor(const std::string &modelPath,
                                            std::shared_ptr<gert::StreamExecutor> executor)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    const auto entry = cachedModels_.find(modelPath);
    if (entry != cachedModels_.end()) {
        entry->second.executor = std::move(executor);
        return ACL_SUCCESS;
    }
    ACL_LOG_INNER_ERROR("search model cache faild when update runtime v2 stream executor, key is %s",
                        modelPath.c_str());
    return ACL_ERROR_FAILURE;
}

/**
 * Call Erase(stream) on the executor of every cached entry that has one.
 *
 * Entries without an executor are skipped. Runs entirely under the cache lock.
 *
 * @param stream  stream handle forwarded to each executor's Erase()
 * @return ACL_SUCCESS (there is no failure path)
 */
aclError OpModelCache::CleanCachedExecutor(rtStream_t stream)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    for (auto &entry : cachedModels_) {
        auto &cachedExecutor = entry.second.executor;
        if (cachedExecutor == nullptr) {
            continue;
        }
        cachedExecutor->Erase(stream);
    }
    return ACL_SUCCESS;
}

void OpModelCache::CleanCachedModels()
{
    cachedModels_.clear();
}
} // namespace acl
